{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 量化注解"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "::::{dropdown}\n",
    "```c++\n",
    "using namespace relay::transform;\n",
    "\n",
    "class QAnnotateExpr;\n",
    "class QAnnotateExprNode : public TempExprNode {\n",
    " public:\n",
    "  Expr expr;\n",
    "  QAnnotateKind kind;\n",
    "\n",
    "  void VisitAttrs(tvm::AttrVisitor* v) {\n",
    "    v->Visit(\"expr\", &expr);\n",
    "    v->Visit(\"kind\", &kind);\n",
    "  }\n",
    "\n",
    "  Expr Realize() const final;\n",
    "\n",
    "  static constexpr const char* _type_key = \"relay.QAnnotateExpr\";\n",
    "  TVM_DECLARE_FINAL_OBJECT_INFO(QAnnotateExprNode, TempExprNode);\n",
    "};\n",
    "\n",
    "class QAnnotateExpr : public TempExpr {\n",
    " public:\n",
    "  /*!\n",
    "   * \\brief The constructor\n",
    "   * \\param expr The original relay expression.\n",
    "   * \\param kind The annotation kind.\n",
    "   */\n",
    "  TVM_DLL QAnnotateExpr(Expr expr, QAnnotateKind kind);\n",
    "\n",
    "  TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode);\n",
    "};\n",
    "\n",
    "Expr QAnnotateExprNode::Realize() const { return expr; }\n",
    "\n",
    "QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) {\n",
    "  auto rnode = make_object<QAnnotateExprNode>();\n",
    "  rnode->expr = std::move(expr);\n",
    "  rnode->kind = kind;\n",
    "  data_ = std::move(rnode);\n",
    "}\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"relay._quantize.make_annotate_expr\").set_body_typed([](Expr expr, int kind) {\n",
    "  return QAnnotateExpr(expr, static_cast<QAnnotateKind>(kind));\n",
    "});\n",
    "```\n",
    "::::"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了一对 TVM 对象：节点类 `QAnnotateExprNode`（继承自 `TempExprNode`）和对应的引用类 `QAnnotateExpr`（继承自 `TempExpr`），用于表示带有量化注解的临时表达式。`QAnnotateExprNode` 保存原始表达式 `expr` 和注解类型 `kind`；`VisitAttrs` 方法将这两个字段暴露给属性访问器；`Realize` 方法返回原始表达式 `expr`。\n",
    "\n",
    "`QAnnotateExpr` 类的构造函数接受一个表达式和一个注解类型作为参数，并将它们存储在 `QAnnotateExprNode` 对象中。`TVM_DEFINE_OBJECT_REF_METHODS` 宏用于定义对象的引用方法。\n",
    "\n",
    "最后，`TVM_REGISTER_GLOBAL` 宏注册了全局函数 `relay._quantize.make_annotate_expr`，该函数接受一个表达式和一个整数形式的注解类型（内部通过 `static_cast` 转换为 `QAnnotateKind`），并返回 `QAnnotateExpr` 对象。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "::::{dropdown}\n",
    "```c++\n",
    "Pass QuantizeAnnotate() {\n",
    "  // TODO(tvm-teams): since partition has added cast_hint in different\n",
    "  // branches, try to remove this in the future.\n",
    "  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {\n",
    "    if (e->IsInstance<TempExprNode>()) {\n",
    "      const auto* n = e.as<QAnnotateExprNode>();\n",
    "      ICHECK(n);\n",
    "      const PackedFunc* f = runtime::Registry::Get(\"relay.quantize.attach_simulated_quantize\");\n",
    "      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));\n",
    "      return static_cast<Expr>(QAnnotateExpr(ret, kQInput));\n",
    "    }\n",
    "    return e;\n",
    "  };\n",
    "\n",
    "  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =\n",
    "      [=](Function f, IRModule m, PassContext pc) {\n",
    "        auto func = Downcast<Function>(ForwardRewrite(f, \"FQAnnotateRewrite\", nullptr, fmulti_ref));\n",
    "        auto new_params = func->params;\n",
    "        for (const auto& x : FreeVars(func)) {\n",
    "          new_params.push_back(x);\n",
    "        }\n",
    "        return WithFields(func, new_params);\n",
    "      };\n",
    "  return CreateFunctionPass(pass_func, 1, \"QuantizeAnnotate\", {});\n",
    "}\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"relay._quantize.QuantizeAnnotate\").set_body_typed(QuantizeAnnotate);\n",
    "\n",
    "TVM_REGISTER_NODE_TYPE(QAnnotateExprNode);\n",
    "```\n",
    "::::"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
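    "这段代码定义了 `QuantizeAnnotate` 函数，它返回一个对 Relay 函数进行量化注解的 `Pass`。具体来说，它首先定义了名为 `fmulti_ref` 的 `lambda` 函数，该函数接受一个表达式作为参数：如果该表达式是 `TempExprNode` 的实例（此处通过 `ICHECK` 要求它必须是 `QAnnotateExprNode`），则调用已注册的 `relay.quantize.attach_simulated_quantize`，以 `kQInput` 类型为其内部表达式附加模拟量化，并将结果重新包装为 `QAnnotateExpr`；否则直接返回原表达式。\n",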
    "\n",
    "接下来，定义了一个名为 `pass_func` 的函数，它接受一个函数、一个 IR 模块和一个 PassContext 作为参数。在这个函数中，首先对输入的函数进行前向重写，然后遍历函数中的自由变量，将它们添加到新的参数列表中。最后，使用新的参数列表创建一个新的函数，并返回。\n",
    "\n",
    "最后，使用 `CreateFunctionPass` 创建名为 `QuantizeAnnotate` 的函数级 Pass，并将 `QuantizeAnnotate` 注册为全局函数 `relay._quantize.QuantizeAnnotate`。同时，还通过 `TVM_REGISTER_NODE_TYPE` 注册了 `QAnnotateExprNode` 节点类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch\n",
    "\n",
    "\n",
    "class Model(nn.Module):\n",
    "    \"\"\"Two stacked 3x3 convolutions whose ReLU activations are summed.\n",
    "\n",
    "    ``conv`` maps 3 -> 16 channels without bias; ``conv2`` maps 16 -> 16\n",
    "    channels with bias. Both use stride 1 and padding 1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs) -> None:\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.conv = nn.Conv2d(\n",
    "            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False\n",
    "        )\n",
    "        self.conv2 = nn.Conv2d(\n",
    "            in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=True\n",
    "        )\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``relu(conv(x)) + relu(conv2(conv(x)))``.\"\"\"\n",
    "        stem = self.conv(x)\n",
    "        skip_branch = self.relu(stem)\n",
    "        deep_branch = self.relu(self.conv2(stem))\n",
    "        return skip_branch + deep_branch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {\n",
      "  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;\n",
      "  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;\n",
      "  %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
      "  %3 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;\n",
      "  %4 = nn.relu(%2) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;\n",
      "  add(%3, %4) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */\n",
      "} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import relay\n",
    "\n",
    "# Fix the seeds so the random input and the randomly initialised weights\n",
    "# are identical on every Restart & Run All.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Input data\n",
    "input_shape = (1, 3, 4, 4)\n",
    "input_dtype = \"float32\"\n",
    "data_np = rng.random(input_shape).astype(input_dtype)\n",
    "with torch.no_grad():\n",
    "    pt_model = Model().eval().float()\n",
    "    traced_model = torch.jit.trace(pt_model, torch.from_numpy(data_np)).eval()\n",
    "# Import the traced model into Relay, then apply relay.quantize's\n",
    "# prerequisite optimizations before printing the main function.\n",
    "mod, params = relay.frontend.from_pytorch(traced_model, [(\"data\", input_shape)],\n",
    "                                          use_parser_friendly_name=True)\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    mod = relay.quantize.prerequisite_optimize(mod, params)\n",
    "print(mod['main'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {\n",
       "  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;\n",
       "  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;\n",
       "  %2 = annotation.cast_hint(%1, dtype=\"int8\") /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %3 = annotation.cast_hint(%0, dtype=\"int8\") /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %4 = annotation.stop_fusion(%3) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %5 = nn.conv2d(%4, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;\n",
       "  %6 = add(%5, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %7 = nn.relu(%6) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;\n",
       "  %8 = annotation.cast_hint(%7, dtype=\"int8\") /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %9 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %10 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %11 = add(%9, %10) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */;\n",
       "  %12 = annotation.cast_hint(%11, dtype=\"int8\") /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  annotation.stop_fusion(%12) /* ty=Tensor[(1, 16, 4, 4), float32] */\n",
       "} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the partition pass to the module and display its main function.\n",
    "partition_pass = relay.quantize.partition()\n",
    "partitioned_mod = partition_pass(mod)\n",
    "partitioned_mod[\"main\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */, %dom_scale: float32 /* ty=float32 */, %clip_min: float32 /* ty=float32 */, %clip_max: float32 /* ty=float32 */, %dom_scale1: float32 /* ty=float32 */, %clip_min1: float32 /* ty=float32 */, %clip_max1: float32 /* ty=float32 */, %dom_scale2: float32 /* ty=float32 */, %clip_min2: float32 /* ty=float32 */, %clip_max2: float32 /* ty=float32 */, %dom_scale3: float32 /* ty=float32 */, %clip_min3: float32 /* ty=float32 */, %clip_max3: float32 /* ty=float32 */, %dom_scale4: float32 /* ty=float32 */, %clip_min4: float32 /* ty=float32 */, %clip_max4: float32 /* ty=float32 */, %dom_scale5: float32 /* ty=float32 */, %clip_min5: float32 /* ty=float32 */, %clip_max5: float32 /* ty=float32 */, %dom_scale6: float32 /* ty=float32 */, %clip_min6: float32 /* ty=float32 */, %clip_max6: float32 /* ty=float32 */) -> Tensor[(1, 16, 4, 4), float32] {\n",
       "  %0 = relay.op.annotation.simulated_quantize(%data, %dom_scale, %clip_min, %clip_max, kind=1) /* ty=Tensor[(1, 3, 4, 4), float32] */;\n",
       "  %1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, %dom_scale1, %clip_min1, %clip_max1, kind=2) /* ty=Tensor[(16, 3, 3, 3), float32] */;\n",
       "  %2 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;\n",
       "  %3 = relay.op.annotation.simulated_quantize(%2, %dom_scale2, %clip_min2, %clip_max2, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;\n",
       "  %5 = annotation.cast_hint(%4, dtype=\"int8\") /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %6 = annotation.cast_hint(%3, dtype=\"int8\") /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %7 = annotation.stop_fusion(%6) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %8 = relay.op.annotation.simulated_quantize(meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, %dom_scale3, %clip_min3, %clip_max3, kind=2) /* ty=Tensor[(16, 16, 3, 3), float32] */;\n",
       "  %9 = nn.conv2d(%7, %8, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;\n",
       "  %10 = relay.op.annotation.simulated_quantize(meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */, %dom_scale4, %clip_min4, %clip_max4, kind=2) /* ty=Tensor[(16, 1, 1), float32] */;\n",
       "  %11 = add(%9, %10) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %12 = nn.relu(%11) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;\n",
       "  %13 = relay.op.annotation.simulated_quantize(%12, %dom_scale5, %clip_min5, %clip_max5, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %14 = annotation.cast_hint(%13, dtype=\"int8\") /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %15 = annotation.stop_fusion(%5) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %16 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %17 = add(%15, %16) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */;\n",
       "  %18 = relay.op.annotation.simulated_quantize(%17, %dom_scale6, %clip_min6, %clip_max6, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  %19 = annotation.cast_hint(%18, dtype=\"int8\") /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
       "  annotation.stop_fusion(%19) /* ty=Tensor[(1, 16, 4, 4), float32] */\n",
       "} /* ty=fn (Tensor[(1, 3, 4, 4), float32], float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32) -> Tensor[(1, 16, 4, 4), float32] */"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chain partition and annotate into one sequential pass, then display the\n",
    "# main function of the transformed module.\n",
    "quantize_stages = [relay.quantize.partition(), relay.quantize.annotate()]\n",
    "passes = tvm.transform.Sequential(quantize_stages)\n",
    "annotated_mod = passes(mod)\n",
    "annotated_mod[\"main\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), int8] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), int8] <span style=\"color: #A2F; font-weight: bold\">*/</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), int8] {\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> qnn<span style=\"color: #A2F; font-weight: bold\">.</span>dequantize(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #008000\">0.1</span>f <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>float32 <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>int32 <span style=\"color: #A2F; font-weight: bold\">*/</span>, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #A2F; font-weight: bold\">=</span> nn<span style=\"color: #A2F; font-weight: bold\">.</span>softmax(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, axis<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  qnn<span style=\"color: #A2F; font-weight: bold\">.</span>quantize(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0.00390625</span>f <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>float32 <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">-</span><span style=\"color: #008000\">128</span> <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>int32 <span style=\"color: #A2F; font-weight: bold\">*/</span>, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), int8] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), int8] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), int8] <span style=\"color: #A2F; font-weight: bold\">*/</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), int8] {\n",
       "  qnn<span style=\"color: #A2F; font-weight: bold\">.</span>softmax(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #008000\">0.1</span>f <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>float32 <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>int32 <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #008000\">0.00390625</span>f <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>float32 <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">-</span><span style=\"color: #008000\">128</span> <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>int32 <span style=\"color: #A2F; font-weight: bold\">*/</span>, axis<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>), int8] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build a fake-quantized graph: dequantize -> softmax -> quantize.\n",
    "shape = [5, 10]\n",
    "scale = 0.1\n",
    "x_ = relay.var(\"x\", shape=shape, dtype=\"int8\")\n",
    "x = relay.qnn.op.dequantize(x_, relay.const(scale), relay.const(0))\n",
    "op = relay.op.nn.softmax(x, axis=1)\n",
    "op = relay.qnn.op.quantize(\n",
    "    op, relay.const(1.0 / 256.0), relay.const(-128), out_dtype=\"int8\"\n",
    ")\n",
    "\n",
    "# Sample int8 input over the full range. The upper bound of\n",
    "# np.random.randint is exclusive, so 128 is required for 127 to be drawn.\n",
    "x_np = np.random.randint(-128, 128, size=shape, dtype=\"int8\")\n",
    "x_np = np.sort(x_np)\n",
    "args = [x_np]\n",
    "\n",
    "# NOTE: this rebinds `mod`, which the partition/annotate cells above also read.\n",
    "mod = tvm.IRModule.from_expr(op)\n",
    "mod = tvm.relay.transform.InferType()(mod)\n",
    "# Rewrite the fake-quantized pattern into integer QNN ops; softmax is opted in\n",
    "# explicitly through optional_qnn_ops.\n",
    "mod_int = tvm.relay.transform.FakeQuantizationToInteger(\n",
    "    hard_fail=True, optional_qnn_ops=[\"nn.softmax\"]\n",
    ")(mod)\n",
    "mod.show()\n",
    "mod_int.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvm-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
